Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix host resolution order in auth login #1370

Merged
merged 25 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package auth

import (
"context"
"errors"
"fmt"
"strings"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
Expand Down Expand Up @@ -34,25 +37,29 @@ GCP: https://docs.gcp.databricks.com/en/dev-tools/auth/index.html`,
}

func promptForHost(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a host using --host")
}

prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Host (e.g. https://<databricks-instance>.cloud.databricks.com)"
// Validate?
host, err := prompt.Run()
if err != nil {
return "", err
prompt.Label = "Databricks Host"
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.Validate = func(host string) error {
if !strings.HasPrefix(host, "https://") {
return fmt.Errorf("host URL must have a https:// prefix")
}
return nil
}
return host, nil
return prompt.Run()
}

func promptForAccountID(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", errors.New("the command is being run in a non-interactive environment, please specify an account ID using --account-id")
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
}

prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Account ID"
prompt.Default = ""
prompt.AllowEdit = true
// Validate?
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
accountId, err := prompt.Run()
if err != nil {
return "", err
}
return accountId, nil
return prompt.Run()
}
95 changes: 63 additions & 32 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,37 @@ import (
"github.com/spf13/cobra"
)

func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, args []string, argIndex int) error {
if len(args) > argIndex {
persistentAuth.Host = args[argIndex]
return nil
func promptForProfile(ctx context.Context, dv string) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", errors.New("the command is being run in a non-interactive environment, please specify a profile using --profile")
}

host, err := promptForHost(ctx)
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Profile Name"
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.Default = dv
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.AllowEdit = true
return prompt.Run()
}

func getHostFromProfile(ctx context.Context, profileName string) (string, error) {
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName
})

// Tolerate ErrNoConfiguration here, as we will write out a configuration file
// as part of the login flow.
if err != nil && errors.Is(err, databrickscfg.ErrNoConfiguration) {
return "", nil
}
if err != nil {
return err
return "", err
}
persistentAuth.Host = host
return nil

// Return host from profile
if len(profiles) > 0 && profiles[0].Host != "" {
return profiles[0].Host, nil
}
return "", nil
}

const minimalDbConnectVersion = "13.1"
Expand All @@ -39,35 +58,30 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {

var loginTimeout time.Duration
var configureCluster bool
var profileName string
cmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout,
"Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
"Prompts to configure cluster")
cmd.Flags().StringVarP(&profileName, "profile", "p", "", `Name of the profile.`)
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

var profileName string
profileFlag := cmd.Flag("profile")
if profileFlag != nil && profileFlag.Value.String() != "" {
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
profileName = profileFlag.Value.String()
} else if cmdio.IsInTTY(ctx) {
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Profile Name"
prompt.Default = persistentAuth.ProfileName()
prompt.AllowEdit = true
profile, err := prompt.Run()
// If the user has not specified a profile name, prompt for one.
if profileName == "" {
var err error
profileName, err = promptForProfile(ctx, persistentAuth.DefaultProfileName())
if err != nil {
return err
}
profileName = profile
}

// Set the host based on the provided arguments and flags.
err := setHost(ctx, profileName, persistentAuth, args)
if err != nil {
return err
}
defer persistentAuth.Close()
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved

// We need the config without the profile before it's used to initialise new workspace client below.
// Otherwise it will complain about non existing profile because it was not yet saved.
Expand All @@ -91,6 +105,7 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
if err != nil {
return err
}
defer persistentAuth.Close()

if configureCluster {
w, err := databricks.NewWorkspaceClient((*databricks.Config)(&cfg))
Expand Down Expand Up @@ -127,21 +142,37 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
return cmd
}

// Sets the host in the persistentAuth object based on the provided arguments and flags.
// Follows the following precedence:
// 1. [HOST] (first positional argument) or --host flag. Error if both are specified.
// 2. Profile host, if available.
// 3. Prompt the user for the host.
func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName
})
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
return err
// If both [HOST] and --host are provided, return an error.
if len(args) > 0 && persistentAuth.Host != "" {
return fmt.Errorf("please only provide a host as an argument or a flag, not both")
}
if persistentAuth.Host == "" {
if len(profiles) > 0 && profiles[0].Host != "" {
persistentAuth.Host = profiles[0].Host
} else {
configureHost(ctx, persistentAuth, args, 0)

// If [HOST] is provided, set the host to the provided positional argument.
if len(args) > 0 && persistentAuth.Host == "" {
persistentAuth.Host = args[0]
}

// If neither [HOST] nor --host are provided, and the profile has a host, use it.
// Otherwise, prompt the user for a host.
if len(args) == 0 && persistentAuth.Host == "" {
hostName, err := getHostFromProfile(ctx, profileName)
if err != nil {
return err
}
if hostName == "" {
var err error
hostName, err = promptForHost(ctx)
if err != nil {
return err
}
}
persistentAuth.Host = hostName
}
return nil
}
35 changes: 35 additions & 0 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,38 @@ func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
err := setHost(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
assert.NoError(t, err)
}

func TestSetHost(t *testing.T) {
ctx := context.Background()
var persistentAuth auth.PersistentAuth
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")

// Test error when both flag and argument are provided
persistentAuth.Host = "val from --host"
err := setHost(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both")

// Test setting host from flag
persistentAuth.Host = "val from --host"
err = setHost(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "val from --host", persistentAuth.Host)

// Test setting host from argument
persistentAuth.Host = ""
err = setHost(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.NoError(t, err)
assert.Equal(t, "val from [HOST]", persistentAuth.Host)

// Test setting host from profile
persistentAuth.Host = ""
err = setHost(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host1.com", persistentAuth.Host)

// Test setting host from profile
persistentAuth.Host = ""
err = setHost(ctx, "profile-2", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host2.com", persistentAuth.Host)
}
5 changes: 5 additions & 0 deletions cmd/auth/testdata/.databrickscfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[profile-1]
host = https://www.host1.com

[profile-2]
host = https://www.host2.com
3 changes: 1 addition & 2 deletions libs/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
return refreshed, nil
}

func (a *PersistentAuth) ProfileName() string {
// TODO: get profile name from interactive input
func (a *PersistentAuth) DefaultProfileName() string {
if a.AccountID != "" {
return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
}
Expand Down
Loading