From f9675ab8eaedd97a4fa6e843508fcb7c5f8b2670 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 12 Sep 2024 10:49:27 +0200 Subject: [PATCH 1/2] hackathon device code flow --- cmd/auth/login.go | 9 +++++- libs/auth/oauth.go | 67 +++++++++++++++++++++++++++++++++++------ libs/auth/oauth_test.go | 35 +++++++++++++++++++++ 3 files changed, 100 insertions(+), 11 deletions(-) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 79b7954680..f9ba41218a 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file var loginTimeout time.Duration var configureCluster bool + var deviceCode bool cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout, "Timeout for completing login challenge in the browser") cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false, "Prompts to configure cluster") + cmd.Flags().BoolVar(&deviceCode, "device-code", false, + "Use device code flow for authentication") cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file ctx, cancel := context.WithTimeout(ctx, loginTimeout) defer cancel() - err = persistentAuth.Challenge(ctx) + if deviceCode { + err = persistentAuth.DeviceCode(ctx) + } else { + err = persistentAuth.Challenge(ctx) + } if err != nil { return err } diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index 7c1cb95768..e6b9e24128 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -55,9 +55,16 @@ type PersistentAuth struct { Host string AccountID string - http *httpclient.ApiClient - cache cache.TokenCache - ln net.Listener + // The client used when making requests to Databricks OAuth endpoints. + http *httpclient.ApiClient + + // A token cache for OAuth access & refresh tokens. + cache cache.TokenCache + + // A listener used to receive the OAuth callback. Not used for device-code flow. + ln net.Listener + + // A function to open a URL in the user's browser. Not used for device-code flow. browser func(string) error } @@ -113,11 +120,41 @@ func (a *PersistentAuth) ProfileName() string { return split[0] } +func (a *PersistentAuth) DeviceCode(ctx context.Context) error { + err := a.init(ctx) + if err != nil { + return fmt.Errorf("init: %w", err) + } + cfg, err := a.oauth2Config(ctx) + if err != nil { + return err + } + ctx = a.http.InContextForOAuth2(ctx) + deviceAuthResp, err := cfg.DeviceAuth(ctx) + if err != nil { + return fmt.Errorf("error initiating device code flow: %w", err) + } + fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode) + token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp) + if err != nil { + return fmt.Errorf("error retrieving token: %w", err) + } + err = a.cache.Store(a.key(), token) + if err != nil { + return fmt.Errorf("store: %w", err) + } + return nil +} + func (a *PersistentAuth) Challenge(ctx context.Context) error { err := a.init(ctx) if err != nil { return fmt.Errorf("init: %w", err) } + err = a.initU2M(ctx) + if err != nil { + return fmt.Errorf("init: %w", err) + } cfg, err := a.oauth2Config(ctx) if err != nil { return err @@ -143,6 +180,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error { return nil } +// init validates that the host and account id are set and initializes the http client and token cache. +// It should be called before any other method on PersistentAuth. func (a *PersistentAuth) init(ctx context.Context) error { if a.Host == "" && a.AccountID == "" { return ErrFetchCredentials @@ -153,6 +192,11 @@ func (a *PersistentAuth) init(ctx context.Context) error { if a.cache == nil { a.cache = cache.GetTokenCache(ctx) } + return nil +} + +// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow. +func (a *PersistentAuth) initU2M(ctx context.Context) error { if a.browser == nil { a.browser = browser.OpenURL } @@ -186,8 +230,9 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization prefix := a.key() if a.AccountID != "" { return &oauthAuthorizationServer{ - AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), - TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), + AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), + TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), + DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix), }, nil } var oauthEndpoints oauthAuthorizationServer @@ -219,9 +264,10 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro return &oauth2.Config{ ClientID: appClientID, Endpoint: oauth2.Endpoint{ - AuthURL: endpoints.AuthorizationEndpoint, - TokenURL: endpoints.TokenEndpoint, - AuthStyle: oauth2.AuthStyleInParams, + AuthURL: endpoints.AuthorizationEndpoint, + TokenURL: endpoints.TokenEndpoint, + DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint, + AuthStyle: oauth2.AuthStyleInParams, }, RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), Scopes: scopes, @@ -260,6 +306,7 @@ func (a *PersistentAuth) randomString(size int) string { } type oauthAuthorizationServer struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize - TokenEndpoint string `json:"token_endpoint"` // ../v1/token + AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize + TokenEndpoint string `json:"token_endpoint"` // ../v1/token + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization } diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go index ea6a8061e6..d52162effb 100644 --- a/libs/auth/oauth_test.go +++ b/libs/auth/oauth_test.go @@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) { assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") }) } + +func TestDeviceCode_Account(t *testing.T) { + qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/device_authorization", + Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`, + }, + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=jkl&refresh_token=mnop`, + }, + }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + tokenStored := false + p := &PersistentAuth{ + Host: c.Config.Host, + AccountID: "xyz", + cache: &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key) + assert.Equal(t, "mnop", tok.RefreshToken) + tokenStored = true + return nil + }, + }, + } + defer p.Close() + + err := p.DeviceCode(ctx) + assert.NoError(t, err) + assert.True(t, tokenStored) + }) +} From 517e1b3310e6536bb305291837dad7c48f2f884d Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 12 Sep 2024 10:51:36 +0200 Subject: [PATCH 2/2] handle edge case --- libs/auth/oauth.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index e6b9e24128..0c250d8c09 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -46,9 +46,10 @@ const ( ) var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence - ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") - ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") - ErrFetchCredentials = errors.New("cannot fetch credentials") + ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") + ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") + ErrFetchCredentials = errors.New("cannot fetch credentials") + ErrDeviceCodeNotSupported = errors.New("device code flow is not supported for this host") ) type PersistentAuth struct { @@ -129,6 +130,9 @@ func (a *PersistentAuth) DeviceCode(ctx context.Context) error { if err != nil { return err } + if cfg.Endpoint.DeviceAuthURL == "" { + return ErrDeviceCodeNotSupported + } ctx = a.http.InContextForOAuth2(ctx) deviceAuthResp, err := cfg.DeviceAuth(ctx) if err != nil {