Skip to content

Commit

Permalink
[KEEP] Introduce DatabricksEnvironment and fix Azure MSI auth from …
Browse files Browse the repository at this point in the history
…ACR, where IMDS doesn't give host environment information
  • Loading branch information
nfx committed Nov 24, 2023
1 parent b0f3c83 commit 18f31cf
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 84 deletions.
8 changes: 6 additions & 2 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (c AzureCliCredentials) Name() string {

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureCliCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource {
ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return &azureCliTokenSource{resource: resource}
}

Expand Down Expand Up @@ -62,8 +62,12 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
if !cfg.IsAzure() {
return nil, nil
}
env, err := cfg.Environment()
if err != nil {
return nil, fmt.Errorf("environment: %w", err)
}
// Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI.
ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID}
ts := &azureCliTokenSource{env.AzureApplicationID, cfg.AzureResourceID}
t, err := ts.Token()
if err != nil {
if strings.Contains(err.Error(), "No subscription found") {
Expand Down
20 changes: 11 additions & 9 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ func (c AzureClientSecretCredentials) Name() string {
}

func (c AzureClientSecretCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource {
ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource {
return (&clientcredentials.Config{
ClientID: cfg.AzureClientID,
ClientSecret: cfg.AzureClientSecret,
TokenURL: fmt.Sprintf("%s%s/oauth2/token", env.ActiveDirectoryEndpoint, cfg.AzureTenantID),
TokenURL: fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, cfg.AzureTenantID),
EndpointParams: url.Values{
"resource": []string{resource},
},
Expand All @@ -43,18 +43,20 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if !cfg.IsAzure() {
return nil, nil
}
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
logger.Infof(ctx, "Generating AAD token for Service Principal (%s)", cfg.AzureClientID)
refreshCtx := context.Background()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID()))
management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint))
env, err := cfg.Environment()
if err != nil {
return nil, fmt.Errorf("environment: %w", err)
}
aadEndpoint := env.AzureEnvironment.ActiveDirectoryEndpoint
managementEndpoint := env.AzureEnvironment.ServiceManagementEndpoint
inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, env.AzureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, managementEndpoint))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}
43 changes: 2 additions & 41 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) {
return nil, nil
}
env, err := c.getInstanceEnvironment(ctx)
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
Expand All @@ -54,47 +54,8 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

// getInstanceEnvironment queries the instance metadata endpoint
// (instanceMetadataPrefix + "/instance") and maps the reported
// compute.azEnvironment value onto one of the known azureEnvironments by Name.
//
// It returns an error when the request cannot be built or sent, when the
// response body cannot be read, when the endpoint answers with a non-200
// status, when the payload is not valid JSON, or when the reported environment
// name matches none of the known Azure environments.
func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		fmt.Sprintf("%s/instance", instanceMetadataPrefix), nil)
	if err != nil {
		return nil, fmt.Errorf("metadata request: %w", err)
	}
	query := req.URL.Query()
	query.Add("api-version", "2021-12-13")
	req.URL.RawQuery = query.Encode()
	req.Header.Add("Metadata", "true")
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("metadata response: %w", err)
	}
	// The body must always be closed, otherwise the underlying connection
	// is leaked and cannot be reused by the transport.
	defer res.Body.Close()
	raw, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("metadata read: %w", err)
	}
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("metadata error: %s", raw)
	}
	// Only the environment name is needed from the metadata document.
	var metadata struct {
		Compute struct {
			Environment string `json:"azEnvironment"`
		} `json:"compute"`
	}
	err = json.Unmarshal(raw, &metadata)
	if err != nil {
		return nil, fmt.Errorf("metadata parse: %w", err)
	}
	for _, v := range azureEnvironments {
		if v.Name == metadata.Compute.Environment {
			// Hand out a private copy so callers never alias the shared table.
			found := v
			return &found, nil
		}
	}
	return nil, fmt.Errorf("cannot determine environment: %s",
		metadata.Compute.Environment)
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _ azureEnvironment, resource string) oauth2.TokenSource {
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return azureMsiTokenSource{
resource: resource,
clientId: cfg.AzureClientID,
Expand Down
8 changes: 8 additions & 0 deletions config/auth_permutations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config,
AzureResourceID: cf.AzureResourceID,
AuthType: cf.AuthType,
}
if client.IsAzure() {
client.DatabricksEnvironments = append(client.DatabricksEnvironments, DatabricksEnvironment{
Cloud: CloudAzure,
DnsZone: cf.Host,
AzureApplicationID: "abc",
AzureEnvironment: &publicCloud,
})
}
err := client.Authenticate(&http.Request{Header: http.Header{}})
if err != nil {
return nil, err
Expand Down
30 changes: 6 additions & 24 deletions config/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package config
import (
"context"
"fmt"
"strings"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
Expand Down Expand Up @@ -39,35 +38,18 @@ var (
ResourceManagerEndpoint: "https://management.chinacloudapi.cn/",
ActiveDirectoryEndpoint: "https://login.chinacloudapi.cn/",
}

germanCloud = azureEnvironment{
Name: "AzureGermanCloud",
ServiceManagementEndpoint: "https://management.core.cloudapi.de/",
ResourceManagerEndpoint: "https://management.microsoftazure.de/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.de/",
}

azureEnvironments = map[string]azureEnvironment{
"CHINA": chinaCloud,
"GERMAN": germanCloud,
"PUBLIC": publicCloud,
"USGOVERNMENT": usGovernmentCloud,
}
)

func (c *Config) GetAzureEnvironment() (azureEnvironment, error) {
if c.AzureEnvironment == "" {
c.AzureEnvironment = "public"
}
env, ok := azureEnvironments[strings.ToUpper(c.AzureEnvironment)]
if !ok {
return env, fmt.Errorf("azure environment not found: %s", c.AzureEnvironment)
env, err := c.Environment()
if err != nil {
return publicCloud, fmt.Errorf("no Databricks environment: %w", err)
}
return env, nil
return *env.AzureEnvironment, nil
}

type azureHostResolver interface {
tokenSourceFor(ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource
tokenSourceFor(ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource
}

func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResolver) error {
Expand All @@ -79,7 +61,7 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol
return err
}
// azure resource ID can also be used in lieu of host by some of the clients, like Terraform
management := ahr.tokenSourceFor(ctx, c, env, env.ResourceManagerEndpoint)
management := ahr.tokenSourceFor(ctx, c, env.ActiveDirectoryEndpoint, env.ResourceManagerEndpoint)
var workspaceMetadata struct {
Properties struct {
WorkspaceURL string `json:"workspaceUrl"`
Expand Down
26 changes: 18 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ type Config struct {
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`

// AzureEnvironment (Public, UsGov, China, Germany) has specific set of API endpoints.
// [Deprecated] AzureEnvironment (Public, UsGov, China, Germany) has specific set of API endpoints.
AzureEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`

// Azure Login Application ID. Must be set if authenticating for non-production workspaces.
// [Deprecated] Azure Login Application ID. Must be set if authenticating for non-production workspaces.
AzureLoginAppID string `name:"azure_login_app_id" env:"DATABRICKS_AZURE_LOGIN_APP_ID" auth:"azure"`

ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth"`
Expand Down Expand Up @@ -108,6 +108,9 @@ type Config struct {
// HTTPTransport can be overriden for unit testing and together with tooling like https://github.com/google/go-replayers
HTTPTransport http.RoundTripper

// Metadata about the environment where Databricks is deployed. Reserved for internal use.
DatabricksEnvironments []DatabricksEnvironment

Loaders []Loader

// marker for configuration resolving
Expand Down Expand Up @@ -142,16 +145,23 @@ func (c *Config) Authenticate(r *http.Request) error {

// IsAzure returns if the client is configured for Azure Databricks.
func (c *Config) IsAzure() bool {
isAzureHost := strings.Contains(c.Host, ".azuredatabricks.net") ||
strings.Contains(c.Host, "databricks.azure.cn") ||
strings.Contains(c.Host, ".databricks.azure.us")

return isAzureHost || c.AzureResourceID != ""
if c.AzureResourceID != "" {
return true
}
env, err := c.Environment()
if err != nil {
return false
}
return env.Cloud == CloudAzure
}

// IsGcp returns if the client is configured for Databricks on Google Cloud.
func (c *Config) IsGcp() bool {
return strings.Contains(c.Host, ".gcp.databricks.com")
env, err := c.Environment()
if err != nil {
return false
}
return env.Cloud == CloudGCP
}

// IsAws returns if the client is configured for Databricks on AWS.
Expand Down
56 changes: 56 additions & 0 deletions config/environments.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package config

import (
"fmt"
"strings"
)

// Cloud names the cloud provider a Databricks environment is deployed on.
type Cloud string

// Known values of Cloud.
const (
	// CloudUnspecified marks an environment with no specific cloud provider.
	CloudUnspecified = Cloud("Unspecified")
	// CloudAWS marks an environment on AWS.
	CloudAWS = Cloud("AWS")
	// CloudAzure marks an environment on Azure.
	CloudAzure = Cloud("Azure")
	// CloudGCP marks an environment on Google Cloud.
	CloudGCP = Cloud("GCP")
)

// DatabricksEnvironment describes one deployment environment of Databricks:
// the cloud it runs on, the DNS zone its hosts live under, and, for Azure,
// the application ID and endpoint set used when requesting tokens.
type DatabricksEnvironment struct {
// Cloud is the cloud provider this environment belongs to.
Cloud Cloud
// DnsZone is the hostname suffix shared by hosts of this environment.
// Config.Environment matches it against the end of the canonical host name,
// and DeploymentURL appends it to a deployment name.
DnsZone string
// AzureApplicationID is the Azure login application ID for this environment.
// The built-in table sets it only on Azure entries.
AzureApplicationID string
// AzureEnvironment points at the Azure endpoint set for this environment.
// It is nil on the built-in non-Azure entries, so check before dereferencing.
AzureEnvironment *azureEnvironment
}

// DeploymentURL builds the HTTPS URL of the deployment called name in this
// environment by appending the environment's DNS zone to name.
func (de DatabricksEnvironment) DeploymentURL(name string) string {
	const urlFormat = "https://%s%s"
	zone := de.DnsZone
	return fmt.Sprintf(urlFormat, name, zone)
}

// envs is the built-in table of known Databricks environments.
//
// Order matters: Config.Environment returns the first entry whose DnsZone is a
// suffix of the host name, so a narrower zone (for example
// ".staging.cloud.databricks.com") must be listed before the broader zone that
// is also a suffix of the same hosts (".cloud.databricks.com").
var envs = []DatabricksEnvironment{
// Local hosts are not tied to any cloud.
{Cloud: CloudUnspecified, DnsZone: "localhost"},

// AWS zones; no Azure-specific fields are set.
{Cloud: CloudAWS, DnsZone: ".dev.databricks.com"},
{Cloud: CloudAWS, DnsZone: ".staging.cloud.databricks.com"},
{Cloud: CloudAWS, DnsZone: ".cloud.databricks.com"},
{Cloud: CloudAWS, DnsZone: ".cloud.databricks.us"},

// Azure zones; each carries the login application ID and the Azure endpoint set to use.
{Cloud: CloudAzure, DnsZone: ".dev.azuredatabricks.net", AzureApplicationID: "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc", AzureEnvironment: &publicCloud},
{Cloud: CloudAzure, DnsZone: ".staging.azuredatabricks.net", AzureApplicationID: "4a67d088-db5c-48f1-9ff2-0aace800ae68", AzureEnvironment: &publicCloud},
{Cloud: CloudAzure, DnsZone: ".azuredatabricks.net", AzureApplicationID: "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", AzureEnvironment: &publicCloud},
{Cloud: CloudAzure, DnsZone: ".databricks.azure.us", AzureApplicationID: "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", AzureEnvironment: &usGovernmentCloud},
{Cloud: CloudAzure, DnsZone: ".databricks.azure.cn", AzureApplicationID: "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", AzureEnvironment: &chinaCloud},

// GCP zones; no Azure-specific fields are set.
{Cloud: CloudGCP, DnsZone: ".dev.gcp.databricks.com"},
{Cloud: CloudGCP, DnsZone: ".staging.gcp.databricks.com"},
{Cloud: CloudGCP, DnsZone: ".gcp.databricks.com"},
}

// Environment resolves the DatabricksEnvironment for the configured host.
//
// The canonical host name is compared against the DnsZone of every candidate
// and the first candidate whose zone is a suffix of the host name wins.
// Entries in c.DatabricksEnvironments are consulted first, in order, followed
// by the built-in envs table.
//
// The result points at a private copy of the matching entry, so callers cannot
// modify either table through it. An error is returned when no entry matches.
func (c *Config) Environment() (*DatabricksEnvironment, error) {
	hostname := c.CanonicalHostName()
	// Walk the two tables one after another instead of append()-ing them:
	// append may write into spare capacity of the caller-owned
	// c.DatabricksEnvironments slice (a data race under concurrent calls)
	// and would otherwise allocate a merged slice on every lookup.
	for _, candidates := range [][]DatabricksEnvironment{c.DatabricksEnvironments, envs} {
		for i := range candidates {
			if strings.HasSuffix(hostname, candidates[i].DnsZone) {
				match := candidates[i]
				return &match, nil
			}
		}
	}
	// TODO: decide whether an unknown host should fall back to a default
	// environment instead of being an error.
	return nil, fmt.Errorf("cannot find DatabricksEnvironment for %s", hostname)
}

0 comments on commit 18f31cf

Please sign in to comment.