Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix host resolution order in auth login #1370

Merged
merged 25 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package auth

import (
"context"
"fmt"
"strings"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/google/uuid"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -34,25 +37,36 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
}

func promptForHost(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a host using --host")
}

prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Host (e.g. https://<databricks-instance>.cloud.databricks.com)"
// Validate?
host, err := prompt.Run()
if err != nil {
return "", err
prompt.Label = "Databricks host"
prompt.Validate = func(host string) error {
if !strings.HasPrefix(host, "https://") {
return fmt.Errorf("host URL must have a https:// prefix")
}
return nil
}
return host, nil
return prompt.Run()
}

func promptForAccountID(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify an account ID using --account-id")
}

prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Account ID"
prompt.Label = "Databricks account id"
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.Default = ""
prompt.AllowEdit = true
// Validate?
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
accountId, err := prompt.Run()
if err != nil {
return "", err
prompt.Validate = func(accountID string) error {
_, err := uuid.Parse(accountID)
if err != nil {
return fmt.Errorf("account ID must be a valid UUID: %w", err)
}
return nil
}
return accountId, nil
return prompt.Run()
}
92 changes: 61 additions & 31 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,37 @@ import (
"github.com/spf13/cobra"
)

func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, args []string, argIndex int) error {
if len(args) > argIndex {
persistentAuth.Host = args[argIndex]
return nil
func promptForProfile(ctx context.Context, dv string) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a profile using --profile")
}

host, err := promptForHost(ctx)
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks profile name"
prompt.Default = dv
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
prompt.AllowEdit = true
return prompt.Run()
}

func getHostFromProfile(ctx context.Context, profileName string) (string, error) {
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName
})

// Tolerate ErrNoConfiguration here, as we will write out a configuration file
// as part of the login flow.
if err != nil && errors.Is(err, databrickscfg.ErrNoConfiguration) {
return "", nil
}
if err != nil {
return err
return "", err
}
persistentAuth.Host = host
return nil

// Return host from profile
if len(profiles) > 0 && profiles[0].Host != "" {
return profiles[0].Host, nil
}
return "", nil
}

const minimalDbConnectVersion = "13.1"
Expand Down Expand Up @@ -91,23 +110,18 @@ depends on the existing profiles you have set in your configuration file

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
profileName := cmd.Flag("profile").Value.String()

var profileName string
profileFlag := cmd.Flag("profile")
if profileFlag != nil && profileFlag.Value.String() != "" {
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
profileName = profileFlag.Value.String()
} else if cmdio.IsInTTY(ctx) {
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Profile Name"
prompt.Default = persistentAuth.ProfileName()
prompt.AllowEdit = true
profile, err := prompt.Run()
// If the user has not specified a profile name, prompt for one.
if profileName == "" {
var err error
profileName, err = promptForProfile(ctx, persistentAuth.DefaultProfileName())
if err != nil {
return err
}
profileName = profile
}

// Set the host based on the provided arguments and flags.
err := setHost(ctx, profileName, persistentAuth, args)
if err != nil {
return err
Expand Down Expand Up @@ -172,21 +186,37 @@ depends on the existing profiles you have set in your configuration file
return cmd
}

// Sets the host in the persistentAuth object based on the provided arguments and flags.
// Follows the following precedence:
// 1. [HOST] (first positional argument) or --host flag. Error if both are specified.
// 2. Profile host, if available.
// 3. Prompt the user for the host.
func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName
})
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
return err
// If both [HOST] and --host are provided, return an error.
if len(args) > 0 && persistentAuth.Host != "" {
return fmt.Errorf("please only provide a host as an argument or a flag, not both")
}

// If [HOST] is provided, set the host to the provided positional argument.
if len(args) > 0 && persistentAuth.Host == "" {
persistentAuth.Host = args[0]
}
if persistentAuth.Host == "" {
if len(profiles) > 0 && profiles[0].Host != "" {
persistentAuth.Host = profiles[0].Host
} else {
configureHost(ctx, persistentAuth, args, 0)

// If neither [HOST] nor --host are provided, and the profile has a host, use it.
// Otherwise, prompt the user for a host.
if len(args) == 0 && persistentAuth.Host == "" {
hostName, err := getHostFromProfile(ctx, profileName)
if err != nil {
return err
}
if hostName == "" {
var err error
hostName, err = promptForHost(ctx)
if err != nil {
return err
}
}
persistentAuth.Host = hostName
}
return nil
}
35 changes: 35 additions & 0 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,38 @@ func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
err := setHost(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
assert.NoError(t, err)
}

func TestSetHost(t *testing.T) {
ctx := context.Background()
var persistentAuth auth.PersistentAuth
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")

// Test error when both flag and argument are provided
persistentAuth.Host = "val from --host"
err := setHost(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both")

// Test setting host from flag
persistentAuth.Host = "val from --host"
err = setHost(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "val from --host", persistentAuth.Host)

// Test setting host from argument
persistentAuth.Host = ""
err = setHost(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.NoError(t, err)
assert.Equal(t, "val from [HOST]", persistentAuth.Host)

// Test setting host from profile
persistentAuth.Host = ""
err = setHost(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host1.com", persistentAuth.Host)

// Test setting host from profile
persistentAuth.Host = ""
err = setHost(ctx, "profile-2", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host2.com", persistentAuth.Host)
}
5 changes: 5 additions & 0 deletions cmd/auth/testdata/.databrickscfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[profile-1]
host = https://www.host1.com

[profile-2]
host = https://www.host2.com
6 changes: 4 additions & 2 deletions libs/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,13 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
return refreshed, nil
}

func (a *PersistentAuth) ProfileName() string {
// TODO: get profile name from interactive input
func (a *PersistentAuth) DefaultProfileName() string {
if a.AccountID != "" {
return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
}
if a.Host == "" {
return "DEFAULT"
}
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
host := strings.TrimPrefix(a.Host, "https://")
split := strings.Split(host, ".")
return split[0]
Expand Down
Loading