Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a method to generate OAuth tokens #886

Merged
merged 19 commits into from
May 17, 2024
Merged
6 changes: 4 additions & 2 deletions config/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/useragent"
)
Expand Down Expand Up @@ -103,6 +104,7 @@ func (noopLoader) Configure(cfg *Config) error { return nil }
type noopAuth struct{}

func (noopAuth) Name() string { return "noop" }
func (noopAuth) Configure(context.Context, *Config) (func(*http.Request) error, error) {
return func(r *http.Request) error { return nil }, nil
func (noopAuth) Configure(context.Context, *Config) (credentials.CredentialsProvider, error) {
visitor := func(r *http.Request) error { return nil }
return credentials.NewCredentialsProvider(visitor), nil
}
5 changes: 3 additions & 2 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"golang.org/x/oauth2"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand Down Expand Up @@ -53,7 +54,7 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() {
return nil, nil
}
Expand Down Expand Up @@ -81,7 +82,7 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, err
}
logger.Infof(ctx, "Using Azure CLI authentication with AAD tokens")
return visitor, nil
return credentials.NewCredentialsProvider(visitor), nil
}

// NewAzureCliTokenSource returns [oauth2.TokenSource] for a passwordless authentication via Azure CLI (`az login`)
Expand Down
12 changes: 6 additions & 6 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestAzureCliCredentials_Valid(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -88,7 +88,7 @@ func TestAzureCliCredentials_ReuseTokens(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

// We verify the headers in the test above.
Expand All @@ -107,7 +107,7 @@ func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -123,7 +123,7 @@ func TestAzureCliCredentials_ValidWithAzureResourceId(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -138,7 +138,7 @@ func TestAzureCliCredentials_Fallback(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -155,7 +155,7 @@ func TestAzureCliCredentials_AlwaysExpired(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)

assert.NoError(t, err)
}
Expand Down
7 changes: 4 additions & 3 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package config
import (
"context"
"fmt"
"net/http"
"net/url"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand All @@ -35,7 +35,7 @@ func (c AzureClientSecretCredentials) tokenSourceFor(
// as we cannot create AKV backed secret scopes when authenticated as SP.
// If we are authenticated as SP and wish to create one we want to fail early.
// Also see https://github.com/databricks/terraform-provider-databricks/issues/1490.
func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.AzureClientID == "" || cfg.AzureClientSecret == "" || cfg.AzureTenantID == "" {
return nil, nil
}
Expand All @@ -52,5 +52,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
managementEndpoint := env.AzureServiceManagementEndpoint()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, env.AzureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, managementEndpoint))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken))
return credentials.NewCredentialsProvider(visitor), nil
hectorcast-db marked this conversation as resolved.
Show resolved Hide resolved
}
6 changes: 4 additions & 2 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"time"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
Expand All @@ -30,7 +31,7 @@ func (c AzureMsiCredentials) Name() string {
return "azure-msi"
}

func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) {
return nil, nil
}
Expand All @@ -44,7 +45,8 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken))
return credentials.NewCredentialsProvider(visitor), nil
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
Expand Down
9 changes: 6 additions & 3 deletions config/auth_basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/base64"
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/credentials"
)

type BasicCredentials struct {
Expand All @@ -14,14 +16,15 @@ func (c BasicCredentials) Name() string {
return "basic"
}

func (c BasicCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c BasicCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.Username == "" || cfg.Password == "" || cfg.Host == "" {
return nil, nil
}
tokenUnB64 := fmt.Sprintf("%s:%s", cfg.Username, cfg.Password)
b64 := base64.StdEncoding.EncodeToString([]byte(tokenUnB64))
return func(r *http.Request) error {
visitor := func(r *http.Request) error {
r.Header.Set("Authorization", fmt.Sprintf("Basic %s", b64))
return nil
}, nil
}
return credentials.NewCredentialsProvider(visitor), nil
}
7 changes: 4 additions & 3 deletions config/auth_databricks_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
)
Expand All @@ -22,7 +22,7 @@ func (c DatabricksCliCredentials) Name() string {
return "databricks-cli"
}

func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.Host == "" {
return nil, nil
}
Expand Down Expand Up @@ -54,7 +54,8 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (f
return nil, err
}
logger.Debugf(ctx, "Using Databricks CLI authentication with Databricks OAuth tokens")
return refreshableVisitor(ts), nil
visitor := refreshableVisitor(ts)
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}

var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected")
Expand Down
12 changes: 6 additions & 6 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"context"
"errors"
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)

var (
authProviders = []CredentialsProvider{
authProviders = []CredentialsStrategy{
PatCredentials{},
BasicCredentials{},
M2mCredentials{},
Expand Down Expand Up @@ -45,23 +45,23 @@ var errorMessage = fmt.Sprintf("cannot configure default credentials, please che
// ErrCannotConfigureAuth (experimental) is returned when no auth is configured
var ErrCannotConfigureAuth = errors.New(errorMessage)

func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
for _, p := range authProviders {
if cfg.AuthType != "" && p.Name() != cfg.AuthType {
// ignore other auth types if one is explicitly enforced
logger.Infof(ctx, "Ignoring %s auth, because %s is preferred", p.Name(), cfg.AuthType)
continue
}
logger.Tracef(ctx, "Attempting to configure auth: %s", p.Name())
visitor, err := p.Configure(ctx, cfg)
credentialsProvider, err := p.Configure(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("%s: %w", p.Name(), err)
}
if visitor == nil {
if credentialsProvider == nil {
continue
}
c.name = p.Name()
return visitor, nil
return credentialsProvider, nil
}
return nil, ErrCannotConfigureAuth
}
7 changes: 4 additions & 3 deletions config/auth_gcp_google_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"context"
"fmt"
"io/ioutil"
"net/http"
"os"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
Expand All @@ -20,7 +20,7 @@ func (c GoogleCredentials) Name() string {
return "google-credentials"
}

func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleCredentials == "" || !cfg.IsGcp() {
return nil, nil
}
Expand All @@ -42,7 +42,8 @@ func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (func(*ht
return nil, fmt.Errorf("could not obtain OAuth2 token from JSON: %w", err)
}
logger.Infof(ctx, "Using Google Credentials")
return serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token"), nil
visitor := serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token")
return credentials.NewCredentialsProvider(visitor), nil
}

// Reads credentials as JSON. Credentials can be either a path to JSON file,
Expand Down
10 changes: 6 additions & 4 deletions config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package config
import (
"context"
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
"google.golang.org/api/impersonate"
Expand All @@ -20,7 +20,7 @@ func (c GoogleDefaultCredentials) Name() string {
return "google-id"
}

func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() {
return nil, nil
}
Expand All @@ -30,7 +30,8 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f
}
if !cfg.IsAccountClient() {
logger.Infof(ctx, "Using Google Default Application Credentials for Workspace")
return refreshableVisitor(inner), nil
visitor := refreshableVisitor(inner)
return credentials.NewCredentialsProvider(visitor), nil
}
// source for generateAccessToken
platform, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
Expand All @@ -44,7 +45,8 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f
return nil, err
}
logger.Infof(ctx, "Using Google Default Application Credentials for Accounts API")
return serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token"), nil
visitor := serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token")
return credentials.NewCredentialsProvider(visitor), nil
hectorcast-db marked this conversation as resolved.
Show resolved Hide resolved
}

func (c GoogleDefaultCredentials) idTokenSource(ctx context.Context, host, serviceAccount string,
Expand Down
7 changes: 4 additions & 3 deletions config/auth_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"context"
"errors"
"fmt"
"net/http"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)
Expand All @@ -22,7 +22,7 @@ func (c M2mCredentials) Name() string {
return "oauth-m2m"
}

func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor suggestion: I would return the most specific type that you can. If someone uses this directly and wants to be able to call the Token method, they'll need to match on the type first. We already know that this is an OAuthCredentialsProvider.

Suggested change
func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials.OAuthCredentialsProvider, error) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then M2mCredentials does not fulfill the CredentialsStrategy interface. Not sure if there is a way around it.

if cfg.ClientID == "" || cfg.ClientSecret == "" {
return nil, nil
}
Expand All @@ -38,7 +38,8 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.
TokenURL: endpoints.TokenEndpoint,
Scopes: []string{"all-apis"},
}).TokenSource(ctx)
return refreshableVisitor(ts), nil
visitor := refreshableVisitor(ts)
return credentials.NewCredentialsProvider(visitor), nil
}

func oidcEndpoints(ctx context.Context, cfg *Config) (*oauthAuthorizationServer, error) {
Expand Down
6 changes: 4 additions & 2 deletions config/auth_metadata_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"time"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -48,7 +49,7 @@ func (c MetadataServiceCredentials) Name() string {
return "metadata-service"
}

func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.MetadataServiceURL == "" || cfg.Host == "" {
return nil, nil
}
Expand All @@ -72,7 +73,8 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config)
if response == nil {
return nil, nil
}
return refreshableVisitor(&ms), nil
visitor := refreshableVisitor(&ms)
return credentials.NewCredentialsProvider(visitor), nil
}

type metadataService struct {
Expand Down
Loading
Loading