Skip to content

Commit

Permalink
Introduce DatabricksEnvironment and fix Azure MSI auth from ACR, wh…
Browse files Browse the repository at this point in the history
…ere IMDS doesn't give host environment information (#700)

## Changes
This PR allows determining Azure Environment from a Databricks account
or workspace hostname, removing the need for a separate
configuration/environment variable and complexities related to Azure MSI
from within ACR.

Similar functionality in Python SDK: 
- databricks/databricks-sdk-py#390

Stacked on top of:
- #699

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->

- [x] `make test` passing
- [x] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
nfx authored Nov 28, 2023
1 parent cc2b104 commit 832d554
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 137 deletions.
10 changes: 3 additions & 7 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (c AzureCliCredentials) Name() string {

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureCliCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource {
ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return &azureCliTokenSource{resource: resource}
}

Expand All @@ -44,11 +44,7 @@ func (c AzureCliCredentials) tokenSourceFor(
// If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service
// management token. Otherwise, we always pass the service management token.
func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner oauth2.TokenSource) (func(*http.Request) error, error) {
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
ts := &azureCliTokenSource{env.ServiceManagementEndpoint, ""}
ts := &azureCliTokenSource{cfg.Environment().AzureServiceManagementEndpoint(), ""}
t, err := ts.Token()
if err != nil {
logger.Debugf(ctx, "Not including service management token in headers: %v", err)
Expand All @@ -63,7 +59,7 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, nil
}
// Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI.
ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID}
ts := &azureCliTokenSource{cfg.Environment().azureApplicationID, cfg.AzureResourceID}
t, err := ts.Token()
if err != nil {
if strings.Contains(err.Error(), "No subscription found") {
Expand Down
17 changes: 8 additions & 9 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ func (c AzureClientSecretCredentials) Name() string {
}

func (c AzureClientSecretCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource {
ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource {
return (&clientcredentials.Config{
ClientID: cfg.AzureClientID,
ClientSecret: cfg.AzureClientSecret,
TokenURL: fmt.Sprintf("%s%s/oauth2/token", env.ActiveDirectoryEndpoint, cfg.AzureTenantID),
TokenURL: fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, cfg.AzureTenantID),
EndpointParams: url.Values{
"resource": []string{resource},
},
Expand All @@ -43,18 +43,17 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if !cfg.IsAzure() {
return nil, nil
}
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
logger.Infof(ctx, "Generating AAD token for Service Principal (%s)", cfg.AzureClientID)
refreshCtx := context.Background()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID()))
management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint))
env := cfg.Environment()
aadEndpoint := env.AzureActiveDirectoryEndpoint()
managementEndpoint := env.AzureServiceManagementEndpoint()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, env.azureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, managementEndpoint))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}
52 changes: 5 additions & 47 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,70 +31,28 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) {
return nil, nil
}
env, err := c.getInstanceEnvironment(ctx)
if err != nil {
return nil, err
}
env := cfg.Environment()
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
if !cfg.IsAccountClient() {
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
}
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: cfg.getAzureLoginAppID(),
resource: env.azureApplicationID,
clientId: cfg.AzureClientID,
})
management := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: env.ServiceManagementEndpoint,
resource: env.AzureServiceManagementEndpoint(),
clientId: cfg.AzureClientID,
})
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
fmt.Sprintf("%s/instance", instanceMetadataPrefix), nil)
if err != nil {
return nil, fmt.Errorf("metadata request: %w", err)
}
query := req.URL.Query()
query.Add("api-version", "2021-12-13")
req.URL.RawQuery = query.Encode()
req.Header.Add("Metadata", "true")
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("metadata response: %w", err)
}
raw, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("metadata read: %w", err)
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("metadata error: %s", raw)
}
var metadata struct {
Compute struct {
Environment string `json:"azEnvironment"`
} `json:"compute"`
}
err = json.Unmarshal(raw, &metadata)
if err != nil {
return nil, fmt.Errorf("metadata parse: %w", err)
}
for _, v := range azureEnvironments {
if v.Name == metadata.Compute.Environment {
return &v, nil
}
}
return nil, fmt.Errorf("cannot determine environment: %s",
metadata.Compute.Environment)
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _ azureEnvironment, resource string) oauth2.TokenSource {
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return azureMsiTokenSource{
resource: resource,
clientId: cfg.AzureClientID,
Expand Down
8 changes: 8 additions & 0 deletions config/auth_permutations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config,
AzureResourceID: cf.AzureResourceID,
AuthType: cf.AuthType,
}
if client.IsAzure() {
client.DatabricksEnvironments = append(client.DatabricksEnvironments, DatabricksEnvironment{
Cloud: CloudAzure,
dnsZone: cf.Host,
azureApplicationID: "abc",
azureEnvironment: &publicCloud,
})
}
err := client.Authenticate(&http.Request{Header: http.Header{}})
if err != nil {
return nil, err
Expand Down
55 changes: 8 additions & 47 deletions config/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package config
import (
"context"
"fmt"
"strings"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
Expand All @@ -20,73 +19,45 @@ type azureEnvironment struct {
// based on github.com/Azure/go-autorest/autorest/azure/azureEnvironments.go
var (
publicCloud = azureEnvironment{
Name: "AzurePublicCloud",
Name: "PUBLIC",
ServiceManagementEndpoint: "https://management.core.windows.net/",
ResourceManagerEndpoint: "https://management.azure.com/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/",
}

usGovernmentCloud = azureEnvironment{
Name: "AzureUSGovernmentCloud",
Name: "USGOVERNMENT",
ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/",
ResourceManagerEndpoint: "https://management.usgovcloudapi.net/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.us/",
}

chinaCloud = azureEnvironment{
Name: "AzureChinaCloud",
Name: "CHINA",
ServiceManagementEndpoint: "https://management.core.chinacloudapi.cn/",
ResourceManagerEndpoint: "https://management.chinacloudapi.cn/",
ActiveDirectoryEndpoint: "https://login.chinacloudapi.cn/",
}

germanCloud = azureEnvironment{
Name: "AzureGermanCloud",
ServiceManagementEndpoint: "https://management.core.cloudapi.de/",
ResourceManagerEndpoint: "https://management.microsoftazure.de/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.de/",
}

azureEnvironments = map[string]azureEnvironment{
"CHINA": chinaCloud,
"GERMAN": germanCloud,
"PUBLIC": publicCloud,
"USGOVERNMENT": usGovernmentCloud,
}
)

func (c *Config) GetAzureEnvironment() (azureEnvironment, error) {
if c.AzureEnvironment == "" {
c.AzureEnvironment = "public"
}
env, ok := azureEnvironments[strings.ToUpper(c.AzureEnvironment)]
if !ok {
return env, fmt.Errorf("azure environment not found: %s", c.AzureEnvironment)
}
return env, nil
}

type azureHostResolver interface {
tokenSourceFor(ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource
tokenSourceFor(ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource
}

func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResolver) error {
if c.AzureResourceID == "" || c.Host != "" {
return nil
}
env, err := c.GetAzureEnvironment()
if err != nil {
return err
}
azureEnv := c.Environment().azureEnvironment
// azure resource ID can also be used in lieu of host by some of the clients, like Terraform
management := ahr.tokenSourceFor(ctx, c, env, env.ResourceManagerEndpoint)
management := ahr.tokenSourceFor(ctx, c, azureEnv.ActiveDirectoryEndpoint, azureEnv.ResourceManagerEndpoint)
var workspaceMetadata struct {
Properties struct {
WorkspaceURL string `json:"workspaceUrl"`
} `json:"properties"`
}
requestURL := env.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01"
err = httpclient.DefaultClient.Do(ctx, "GET", requestURL,
requestURL := azureEnv.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01"
err := httpclient.DefaultClient.Do(ctx, "GET", requestURL,
httpclient.WithResponseUnmarshal(&workspaceMetadata),
httpclient.WithTokenSource(management),
)
Expand All @@ -97,13 +68,3 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol
logger.Debugf(ctx, "Discovered workspace url: %s", c.Host)
return nil
}

// Resource ID of the Azure application we need to log in.
const azureDatabricksLoginAppID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

func (c *Config) getAzureLoginAppID() string {
if c.AzureLoginAppID != "" {
return c.AzureLoginAppID
}
return azureDatabricksLoginAppID
}
19 changes: 0 additions & 19 deletions config/azure_test.go

This file was deleted.

23 changes: 15 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,15 @@ type Config struct {
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`

// AzureEnvironment (Public, UsGov, China, Germany) has specific set of API endpoints.
// AzureEnvironment (PUBLIC, USGOVERNMENT, CHINA) has specific set of API endpoints. Starting from v0.26.0,
// the environment is determined based on the workspace hostname, if it's specified.
AzureEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`

// Azure Login Application ID. Must be set if authenticating for non-production workspaces.
// Azure Login Application ID. Must be set if authenticating for non-production workspaces. Starting from v0.26.0,
// the correct Azure Login App ID is determined based on the Azure Databricks Workspace hostname.
//
// Deprecated: this configuration property no longer has any effect and will be removed in the future
// versions of Go SDK.
AzureLoginAppID string `name:"azure_login_app_id" env:"DATABRICKS_AZURE_LOGIN_APP_ID" auth:"azure"`

ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth"`
Expand Down Expand Up @@ -108,6 +113,9 @@ type Config struct {
// HTTPTransport can be overriden for unit testing and together with tooling like https://github.com/google/go-replayers
HTTPTransport http.RoundTripper

// Metadata about the environment, where Databricks is deployed. Reserved for internal use.
DatabricksEnvironments []DatabricksEnvironment

Loaders []Loader

// marker for configuration resolving
Expand Down Expand Up @@ -142,16 +150,15 @@ func (c *Config) Authenticate(r *http.Request) error {

// IsAzure returns if the client is configured for Azure Databricks.
func (c *Config) IsAzure() bool {
isAzureHost := strings.Contains(c.Host, ".azuredatabricks.net") ||
strings.Contains(c.Host, "databricks.azure.cn") ||
strings.Contains(c.Host, ".databricks.azure.us")

return isAzureHost || c.AzureResourceID != ""
if c.AzureResourceID != "" {
return true
}
return c.Environment().Cloud == CloudAzure
}

// IsGcp returns if the client is configured for Databricks on Google Cloud.
func (c *Config) IsGcp() bool {
return strings.Contains(c.Host, ".gcp.databricks.com")
return c.Environment().Cloud == CloudGCP
}

// IsAws returns if the client is configured for Databricks on AWS.
Expand Down
Loading

0 comments on commit 832d554

Please sign in to comment.