Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix host resolution order in auth login #1370

Merged
merged 25 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"context"
"fmt"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
Expand Down Expand Up @@ -34,25 +35,23 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
}

func promptForHost(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a host using --host")
}

prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Host (e.g. https://<databricks-instance>.cloud.databricks.com)"
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
// Validate?
host, err := prompt.Run()
if err != nil {
return "", err
}
return host, nil
return prompt.Run()
}

func promptForAccountID(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify an account ID using --account-id")
}

prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Account ID"
prompt.Label = "Databricks account id"
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.Default = ""
prompt.AllowEdit = true
// Validate?
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
accountId, err := prompt.Run()
if err != nil {
return "", err
}
return accountId, nil
return prompt.Run()
}
71 changes: 48 additions & 23 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@ import (
"github.com/spf13/cobra"
)

func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, args []string, argIndex int) error {
if len(args) > argIndex {
persistentAuth.Host = args[argIndex]
return nil
func promptForProfile(ctx context.Context, dv string) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a profile using --profile")
}

host, err := promptForHost(ctx)
if err != nil {
return err
}
persistentAuth.Host = host
return nil
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks profile name"
prompt.Default = dv
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.AllowEdit = true
return prompt.Run()
}

const minimalDbConnectVersion = "13.1"
Expand Down Expand Up @@ -93,23 +91,18 @@ depends on the existing profiles you have set in your configuration file

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
profileName := cmd.Flag("profile").Value.String()

var profileName string
profileFlag := cmd.Flag("profile")
if profileFlag != nil && profileFlag.Value.String() != "" {
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
profileName = profileFlag.Value.String()
} else if cmdio.IsInTTY(ctx) {
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Profile Name"
prompt.Default = persistentAuth.ProfileName()
prompt.AllowEdit = true
profile, err := prompt.Run()
// If the user has not specified a profile name, prompt for one.
if profileName == "" {
var err error
profileName, err = promptForProfile(ctx, persistentAuth.ProfileName())
if err != nil {
return err
}
profileName = profile
}

// Set the host and account-id based on the provided arguments and flags.
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
if err != nil {
return err
Expand Down Expand Up @@ -167,7 +160,23 @@ depends on the existing profiles you have set in your configuration file
return cmd
}

// Sets the host in the persistentAuth object based on the provided arguments and flags.
// Follows the following precedence:
// 1. [HOST] (first positional argument) or --host flag. Error if both are specified.
// 2. Profile host, if available.
// 3. Prompt the user for the host.
//
// Set the account in the persistentAuth object based on the flags.
// Follows the following precedence:
// 1. --account-id flag.
// 2. account-id from the specified profile, if available.
// 3. Prompt the user for the account-id.
func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
// If both [HOST] and --host are provided, return an error.
if len(args) > 0 && persistentAuth.Host != "" {
return fmt.Errorf("please only provide a host as an argument or a flag, not both")
}

profiler := profile.GetProfiler(ctx)
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName))
Expand All @@ -176,18 +185,34 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth
return err
}

if persistentAuth.Host == "" {
// If [HOST] is provided, set the host to the provided positional argument.
if len(args) > 0 && persistentAuth.Host == "" {
persistentAuth.Host = args[0]
}

// If neither [HOST] nor --host are provided, and the profile has a host, use it.
// Otherwise, prompt the user for a host.
if len(args) == 0 && persistentAuth.Host == "" {
if len(profiles) > 0 && profiles[0].Host != "" {
persistentAuth.Host = profiles[0].Host
} else {
configureHost(ctx, persistentAuth, args, 0)
hostName, err := promptForHost(ctx)
if err != nil {
return err
}
persistentAuth.Host = hostName
}
}
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved

// If the account-id was not provided as a cmd line flag, try to read it from
// the specified profile.
isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient()
if isAccountClient && persistentAuth.AccountID == "" {
if len(profiles) > 0 && profiles[0].AccountID != "" {
persistentAuth.AccountID = profiles[0].AccountID
} else {
// Prompt user for the account-id if it we could not get it from a
// profile.
accountId, err := promptForAccountID(ctx)
if err != nil {
return err
Expand Down
68 changes: 68 additions & 0 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"testing"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
Expand All @@ -15,3 +17,69 @@ func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
assert.NoError(t, err)
}

func TestSetHost(t *testing.T) {
var persistentAuth auth.PersistentAuth
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")
ctx, _ := cmdio.SetupTest(context.Background())

// Test error when both flag and argument are provided
persistentAuth.Host = "val from --host"
err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both")

// Test setting host from flag
persistentAuth.Host = "val from --host"
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "val from --host", persistentAuth.Host)

// Test setting host from argument
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.NoError(t, err)
assert.Equal(t, "val from [HOST]", persistentAuth.Host)

// Test setting host from profile
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host1.com", persistentAuth.Host)

// Test setting host from profile
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host2.com", persistentAuth.Host)

// Test host is not set. Should prompt.
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "", &persistentAuth, []string{})
assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host")
}

func TestSetAccountId(t *testing.T) {
var persistentAuth auth.PersistentAuth
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")
ctx, _ := cmdio.SetupTest(context.Background())

// Test setting account-id from flag
persistentAuth.AccountID = "val from --account-id"
err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host)
assert.Equal(t, "val from --account-id", persistentAuth.AccountID)

// Test setting account_id from profile
persistentAuth.AccountID = ""
err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{})
require.NoError(t, err)
assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host)
assert.Equal(t, "id-from-profile", persistentAuth.AccountID)

// Neither flag nor profile account-id is set, should prompt
persistentAuth.AccountID = ""
persistentAuth.Host = "https://accounts.cloud.databricks.com"
err = setHostAndAccountId(ctx, "", &persistentAuth, []string{})
assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id")
}
9 changes: 9 additions & 0 deletions cmd/auth/testdata/.databrickscfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[profile-1]
host = https://www.host1.com

[profile-2]
host = https://www.host2.com

[account-profile]
host = https://accounts.cloud.databricks.com
account_id = id-from-profile
1 change: 0 additions & 1 deletion libs/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
}

func (a *PersistentAuth) ProfileName() string {
// TODO: get profile name from interactive input
if a.AccountID != "" {
return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
}
Expand Down
Loading