diff --git a/.codegen/accounts.go.tmpl b/.codegen/accounts.go.tmpl index 6138f84bb..24c2a8646 100644 --- a/.codegen/accounts.go.tmpl +++ b/.codegen/accounts.go.tmpl @@ -35,7 +35,7 @@ func NewAccountClient(c ...*Config) (*AccountClient, error) { if err != nil { return nil, err } - if cfg.AccountID == "" || !cfg.IsAccountClient() { + if !cfg.IsAccountClient() { return nil, ErrNotAccountClient } apiClient, err := client.New(cfg) diff --git a/account_client.go b/account_client.go index 2596a826e..4bfb8c28f 100755 --- a/account_client.go +++ b/account_client.go @@ -282,7 +282,7 @@ func NewAccountClient(c ...*Config) (*AccountClient, error) { if err != nil { return nil, err } - if cfg.AccountID == "" || !cfg.IsAccountClient() { + if !cfg.IsAccountClient() { return nil, ErrNotAccountClient } apiClient, err := client.New(cfg) diff --git a/common/environment/environments.go b/common/environment/environments.go index 57acd99d7..8c019c90d 100644 --- a/common/environment/environments.go +++ b/common/environment/environments.go @@ -45,6 +45,10 @@ func (de DatabricksEnvironment) AzureActiveDirectoryEndpoint() string { return de.AzureEnvironment.ActiveDirectoryEndpoint } +func (de DatabricksEnvironment) AccountsHost() string { + return "https://accounts" + de.DnsZone +} + // we default to AWS Prod environment since this case will be a hit for PVC func DefaultEnvironment() DatabricksEnvironment { return DatabricksEnvironment{ diff --git a/config/auth_databricks_cli.go b/config/auth_databricks_cli.go index 295557591..cd7cc554c 100644 --- a/config/auth_databricks_cli.go +++ b/config/auth_databricks_cli.go @@ -27,7 +27,7 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (c return nil, nil } - ts, err := newDatabricksCliTokenSource(cfg) + ts, err := newDatabricksCliTokenSource(ctx, cfg) if err != nil { if errors.Is(err, exec.ErrNotFound) { logger.Debugf(ctx, "Most likely the Databricks CLI is not installed") @@ -61,17 +61,12 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (c var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected") type databricksCliTokenSource struct { + ctx context.Context name string - args []string + cfg *Config } -func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error) { - args := []string{"auth", "token", "--host", cfg.Host} - - if cfg.IsAccountClient() { - args = append(args, "--account-id", cfg.AccountID) - } - +func newDatabricksCliTokenSource(ctx context.Context, cfg *Config) (*databricksCliTokenSource, error) { databricksCliPath := cfg.DatabricksCliPath if databricksCliPath == "" { databricksCliPath = "databricks" @@ -101,16 +96,43 @@ func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error) return nil, errLegacyDatabricksCli } - return &databricksCliTokenSource{name: path, args: args}, nil + return &databricksCliTokenSource{ctx: ctx, name: path, cfg: cfg}, nil } func (ts *databricksCliTokenSource) Token() (*oauth2.Token, error) { - out, err := exec.Command(ts.name, ts.args...).Output() + baseArgs := []string{"auth", "token"} + if ts.cfg.IsAccountClient() { + args := append(baseArgs, "--host", ts.cfg.Host, "--account-id", ts.cfg.AccountID) + return ts.tokenInner(args) + } + // Try workspace-level auth first, falling back to account-level auth if account ID is available + args := append(baseArgs, "--host", ts.cfg.Host) + t, wsErr := ts.tokenInner(args) + if wsErr == nil { + return t, nil + } + if ts.cfg.AccountID == "" { + return nil, wsErr + } + logger.Debugf(ts.ctx, "account ID available, falling back to account-level authentication") + args = append(baseArgs, "--host", ts.cfg.Environment().AccountsHost(), "--account-id", ts.cfg.AccountID) + t, acctErr := ts.tokenInner(args) + if acctErr == nil { + return t, nil + } + return nil, acctErr +} + +func (ts *databricksCliTokenSource) tokenInner(args []string) (*oauth2.Token, error) { + logger.Debugf(ts.ctx, "running command: '%s %s'", ts.name, strings.Join(args, " ")) + out, err := exec.Command(ts.name, args...).Output() if ee, ok := err.(*exec.ExitError); ok { + logger.Debugf(ts.ctx, "command '%s %s' failed: %s", ts.name, strings.Join(args, " "), string(ee.Stderr)) return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr)) } if err != nil { - return nil, fmt.Errorf("cannot get access token: %v", err) + logger.Debugf(ts.ctx, "command '%s %s' failed to run: %w", ts.name, strings.Join(args, " "), err) + return nil, fmt.Errorf("cannot get access token: %w", err) } var t oauth2.Token err = json.Unmarshal(out, &t) diff --git a/config/auth_databricks_cli_test.go b/config/auth_databricks_cli_test.go index 5566d93c0..b46c342a4 100644 --- a/config/auth_databricks_cli_test.go +++ b/config/auth_databricks_cli_test.go @@ -13,40 +13,54 @@ import ( var cliDummy = &Config{Host: "https://abc.cloud.databricks.com/"} -func writeSmallDummyExecutable(t *testing.T, path string) { - f, err := os.Create(filepath.Join(path, "databricks")) - require.NoError(t, err) - defer f.Close() - err = os.Chmod(f.Name(), 0755) - require.NoError(t, err) - _, err = f.WriteString("#!/bin/sh\necho hello world\n") - require.NoError(t, err) +const smallExecutable = `#!/bin/sh +echo hello world +` + +const largeExecutable = `#!/bin/sh +cat <&2 + touch "$(dirname "$0")/.token_file" + exit 1 +fi - f.WriteString(` cat < 0 { + err = f.Truncate(int64(truncateSize)) + require.NoError(t, err) + } } func TestDatabricksCliCredentials_SkipAzure(t *testing.T) { @@ -73,7 +87,7 @@ func TestDatabricksCliCredentials_NotInstalled(t *testing.T) { func TestDatabricksCliCredentials_InstalledLegacy(t *testing.T) { // Create a dummy databricks executable. tmp := t.TempDir() - writeSmallDummyExecutable(t, tmp) + writeDummyExecutable(t, tmp, smallExecutable, 0) t.Setenv("PATH", tmp) aa := DatabricksCliCredentials{} @@ -85,7 +99,7 @@ func TestDatabricksCliCredentials_InstalledLegacyWithSymlink(t *testing.T) { // Create a dummy databricks executable. tmp1 := t.TempDir() tmp2 := t.TempDir() - writeSmallDummyExecutable(t, tmp1) + writeDummyExecutable(t, tmp1, smallExecutable, 0) os.Symlink(filepath.Join(tmp1, "databricks"), filepath.Join(tmp2, "databricks")) t.Setenv("PATH", tmp2+string(os.PathListSeparator)+os.Getenv("PATH")) @@ -99,10 +113,22 @@ func TestDatabricksCliCredentials_InstalledNew(t *testing.T) { // Create a dummy databricks executable. tmp := t.TempDir() - writeLargeDummyExecutable(t, tmp) + writeDummyExecutable(t, tmp, largeExecutable, 1024 * 1024) t.Setenv("PATH", tmp+string(os.PathListSeparator)+os.Getenv("PATH")) aa := DatabricksCliCredentials{} _, err := aa.Configure(context.Background(), cliDummy) require.NoError(t, err) } + +func TestDatabricksCliCredentials_FallbackToAccountLevel(t *testing.T) { + env.CleanupEnvironment(t) + + tmp := t.TempDir() + writeDummyExecutable(t, tmp, failFirstSucceedThereafter, 0) + t.Setenv("PATH", tmp+string(os.PathListSeparator)+os.Getenv("PATH")) + + aa := DatabricksCliCredentials{} + _, err := aa.Configure(context.Background(), cliDummy) + require.NoError(t, err) +} \ No newline at end of file diff --git a/config/auth_m2m.go b/config/auth_m2m.go index 5399228fe..a75221ee4 100644 --- a/config/auth_m2m.go +++ b/config/auth_m2m.go @@ -44,7 +44,7 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials func oidcEndpoints(ctx context.Context, cfg *Config) (*oauthAuthorizationServer, error) { prefix := cfg.Host - if cfg.IsAccountClient() && cfg.AccountID != "" { + if cfg.IsAccountClient() { // TODO: technically, we could use the same config profile for both workspace // and account, but we have to add logic for determining accounts host from // workspace host. diff --git a/config/config.go b/config/config.go index 4f29c3746..53a21f366 100644 --- a/config/config.go +++ b/config/config.go @@ -248,18 +248,17 @@ func (c *Config) IsAws() bool { // IsAccountClient returns true if client is configured for Accounts API func (c *Config) IsAccountClient() bool { - if c.AccountID != "" && c.isTesting { + if c.AccountID == "" { + return false + } + if c.isTesting { return true } - - accountsPrefixes := []string{ - "https://accounts.", - "https://accounts-dod.", + if c.Host == c.Environment().AccountsHost() { + return true } - for _, prefix := range accountsPrefixes { - if strings.HasPrefix(c.Host, prefix) { - return true - } + if strings.HasPrefix(c.Host, "https://accounts-dod.") { + return true } return false }