Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce DatabricksEnvironment and fix Azure MSI auth from ACR, where IMDS doesn't give host environment information #700

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (c AzureCliCredentials) Name() string {

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureCliCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource {
ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return &azureCliTokenSource{resource: resource}
}

Expand All @@ -44,11 +44,7 @@ func (c AzureCliCredentials) tokenSourceFor(
// If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service
// management token. Otherwise, we always pass the service management token.
func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner oauth2.TokenSource) (func(*http.Request) error, error) {
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
ts := &azureCliTokenSource{env.ServiceManagementEndpoint, ""}
ts := &azureCliTokenSource{cfg.Environment().AzureServiceManagementEndpoint(), ""}
t, err := ts.Token()
if err != nil {
logger.Debugf(ctx, "Not including service management token in headers: %v", err)
Expand All @@ -63,7 +59,7 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, nil
}
// Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI.
ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID}
ts := &azureCliTokenSource{cfg.Environment().azureApplicationID, cfg.AzureResourceID}
t, err := ts.Token()
if err != nil {
if strings.Contains(err.Error(), "No subscription found") {
Expand Down
17 changes: 8 additions & 9 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ func (c AzureClientSecretCredentials) Name() string {
}

func (c AzureClientSecretCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource {
ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource {
return (&clientcredentials.Config{
ClientID: cfg.AzureClientID,
ClientSecret: cfg.AzureClientSecret,
TokenURL: fmt.Sprintf("%s%s/oauth2/token", env.ActiveDirectoryEndpoint, cfg.AzureTenantID),
TokenURL: fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, cfg.AzureTenantID),
EndpointParams: url.Values{
"resource": []string{resource},
},
Expand All @@ -43,18 +43,17 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if !cfg.IsAzure() {
return nil, nil
}
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
logger.Infof(ctx, "Generating AAD token for Service Principal (%s)", cfg.AzureClientID)
refreshCtx := context.Background()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID()))
management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint))
env := cfg.Environment()
aadEndpoint := env.AzureActiveDirectoryEndpoint()
managementEndpoint := env.AzureServiceManagementEndpoint()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, env.azureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, managementEndpoint))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}
52 changes: 5 additions & 47 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,70 +31,28 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) {
return nil, nil
}
env, err := c.getInstanceEnvironment(ctx)
if err != nil {
return nil, err
}
env := cfg.Environment()
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
if !cfg.IsAccountClient() {
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
}
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: cfg.getAzureLoginAppID(),
resource: env.azureApplicationID,
clientId: cfg.AzureClientID,
})
management := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: env.ServiceManagementEndpoint,
resource: env.AzureServiceManagementEndpoint(),
clientId: cfg.AzureClientID,
})
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
fmt.Sprintf("%s/instance", instanceMetadataPrefix), nil)
if err != nil {
return nil, fmt.Errorf("metadata request: %w", err)
}
query := req.URL.Query()
query.Add("api-version", "2021-12-13")
req.URL.RawQuery = query.Encode()
req.Header.Add("Metadata", "true")
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("metadata response: %w", err)
}
raw, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("metadata read: %w", err)
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("metadata error: %s", raw)
}
var metadata struct {
Compute struct {
Environment string `json:"azEnvironment"`
} `json:"compute"`
}
err = json.Unmarshal(raw, &metadata)
if err != nil {
return nil, fmt.Errorf("metadata parse: %w", err)
}
for _, v := range azureEnvironments {
if v.Name == metadata.Compute.Environment {
return &v, nil
}
}
return nil, fmt.Errorf("cannot determine environment: %s",
metadata.Compute.Environment)
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _ azureEnvironment, resource string) oauth2.TokenSource {
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return azureMsiTokenSource{
resource: resource,
clientId: cfg.AzureClientID,
Expand Down
8 changes: 8 additions & 0 deletions config/auth_permutations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config,
AzureResourceID: cf.AzureResourceID,
AuthType: cf.AuthType,
}
if client.IsAzure() {
client.DatabricksEnvironments = append(client.DatabricksEnvironments, DatabricksEnvironment{
Cloud: CloudAzure,
dnsZone: cf.Host,
azureApplicationID: "abc",
azureEnvironment: &publicCloud,
})
}
err := client.Authenticate(&http.Request{Header: http.Header{}})
if err != nil {
return nil, err
Expand Down
55 changes: 8 additions & 47 deletions config/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package config
import (
"context"
"fmt"
"strings"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
Expand All @@ -20,73 +19,45 @@ type azureEnvironment struct {
// based on github.com/Azure/go-autorest/autorest/azure/azureEnvironments.go
var (
publicCloud = azureEnvironment{
Name: "AzurePublicCloud",
Name: "PUBLIC",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ServiceManagementEndpoint: "https://management.core.windows.net/",
ResourceManagerEndpoint: "https://management.azure.com/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/",
}

usGovernmentCloud = azureEnvironment{
Name: "AzureUSGovernmentCloud",
Name: "USGOVERNMENT",
ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/",
ResourceManagerEndpoint: "https://management.usgovcloudapi.net/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.us/",
}

chinaCloud = azureEnvironment{
Name: "AzureChinaCloud",
Name: "CHINA",
ServiceManagementEndpoint: "https://management.core.chinacloudapi.cn/",
ResourceManagerEndpoint: "https://management.chinacloudapi.cn/",
ActiveDirectoryEndpoint: "https://login.chinacloudapi.cn/",
}

germanCloud = azureEnvironment{
Name: "AzureGermanCloud",
ServiceManagementEndpoint: "https://management.core.cloudapi.de/",
ResourceManagerEndpoint: "https://management.microsoftazure.de/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.de/",
}

azureEnvironments = map[string]azureEnvironment{
"CHINA": chinaCloud,
"GERMAN": germanCloud,
"PUBLIC": publicCloud,
"USGOVERNMENT": usGovernmentCloud,
}
)

func (c *Config) GetAzureEnvironment() (azureEnvironment, error) {
if c.AzureEnvironment == "" {
c.AzureEnvironment = "public"
}
env, ok := azureEnvironments[strings.ToUpper(c.AzureEnvironment)]
if !ok {
return env, fmt.Errorf("azure environment not found: %s", c.AzureEnvironment)
}
return env, nil
}

type azureHostResolver interface {
tokenSourceFor(ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource
tokenSourceFor(ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource
}

func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResolver) error {
if c.AzureResourceID == "" || c.Host != "" {
return nil
}
env, err := c.GetAzureEnvironment()
if err != nil {
return err
}
azureEnv := c.Environment().azureEnvironment
// azure resource ID can also be used in lieu of host by some of the clients, like Terraform
management := ahr.tokenSourceFor(ctx, c, env, env.ResourceManagerEndpoint)
management := ahr.tokenSourceFor(ctx, c, azureEnv.ActiveDirectoryEndpoint, azureEnv.ResourceManagerEndpoint)
var workspaceMetadata struct {
Properties struct {
WorkspaceURL string `json:"workspaceUrl"`
} `json:"properties"`
}
requestURL := env.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01"
err = httpclient.DefaultClient.Do(ctx, "GET", requestURL,
requestURL := azureEnv.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01"
err := httpclient.DefaultClient.Do(ctx, "GET", requestURL,
httpclient.WithResponseUnmarshal(&workspaceMetadata),
httpclient.WithTokenSource(management),
)
Expand All @@ -97,13 +68,3 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol
logger.Debugf(ctx, "Discovered workspace url: %s", c.Host)
return nil
}

// Resource ID of the Azure application we need to log in.
const azureDatabricksLoginAppID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

func (c *Config) getAzureLoginAppID() string {
if c.AzureLoginAppID != "" {
return c.AzureLoginAppID
}
return azureDatabricksLoginAppID
}
19 changes: 0 additions & 19 deletions config/azure_test.go

This file was deleted.

23 changes: 15 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,15 @@ type Config struct {
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`

// AzureEnvironment (Public, UsGov, China, Germany) has specific set of API endpoints.
// AzureEnvironment (PUBLIC, USGOVERNMENT, CHINA) has specific set of API endpoints. Starting from v0.26.0,
// the environment is determined based on the workspace hostname, if it's specified.
AzureEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`

// Azure Login Application ID. Must be set if authenticating for non-production workspaces.
// Azure Login Application ID. Must be set if authenticating for non-production workspaces. Starting from v0.26.0,
// the correct Azure Login App ID is determined based on the Azure Databricks Workspace hostname.
//
// Deprecated: this configuration property no longer has any effect and will be removed in the future
// versions of Go SDK.
AzureLoginAppID string `name:"azure_login_app_id" env:"DATABRICKS_AZURE_LOGIN_APP_ID" auth:"azure"`

ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth"`
Expand Down Expand Up @@ -108,6 +113,9 @@ type Config struct {
// HTTPTransport can be overriden for unit testing and together with tooling like https://github.com/google/go-replayers
HTTPTransport http.RoundTripper

// Metadata about the environment, where Databricks is deployed. Reserved for internal use.
DatabricksEnvironments []DatabricksEnvironment

Loaders []Loader

// marker for configuration resolving
Expand Down Expand Up @@ -142,16 +150,15 @@ func (c *Config) Authenticate(r *http.Request) error {

// IsAzure returns if the client is configured for Azure Databricks.
func (c *Config) IsAzure() bool {
isAzureHost := strings.Contains(c.Host, ".azuredatabricks.net") ||
strings.Contains(c.Host, "databricks.azure.cn") ||
strings.Contains(c.Host, ".databricks.azure.us")

return isAzureHost || c.AzureResourceID != ""
if c.AzureResourceID != "" {
nfx marked this conversation as resolved.
Show resolved Hide resolved
return true
}
return c.Environment().Cloud == CloudAzure
}

// IsGcp returns if the client is configured for Databricks on Google Cloud.
func (c *Config) IsGcp() bool {
return strings.Contains(c.Host, ".gcp.databricks.com")
return c.Environment().Cloud == CloudGCP
}

// IsAws returns if the client is configured for Databricks on AWS.
Expand Down
Loading
Loading