Skip to content

Commit

Permalink
Add method to get and cache dataplane data
Browse files Browse the repository at this point in the history
  • Loading branch information
hectorcast-db committed Apr 11, 2024
1 parent 56ff775 commit ffaab5d
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 3 deletions.
8 changes: 5 additions & 3 deletions httpclient/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ func NewApiClient(cfg ClientConfig) *ApiClient {
Timeout: 0,
Transport: transport,
},
dataPlaneCache: make(dataPlaneCache),
}
}

type ApiClient struct {
config ClientConfig
rateLimiter *rate.Limiter
httpClient *http.Client
config ClientConfig
rateLimiter *rate.Limiter
httpClient *http.Client
dataPlaneCache dataPlaneCache
}

type DoOption struct {
Expand Down
112 changes: 112 additions & 0 deletions httpclient/dataplane.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package httpclient

import (
"fmt"
"time"

"github.com/databricks/databricks-sdk-go/credentials"
"golang.org/x/oauth2"
)

type DataPlaneInfoKey struct {
ServiceName string
Path string
}

// TODO: Replace by the DataPlaneInfo generated by SDK Generator in the package oauth2.
// This will be generated once the first service uses the annotation for the DataPlaneInfo.
type DataPlaneInfo struct {
EndpointUrl string
AuthorizationDetails string
}

type DataPlaneData struct {
DataPlaneInfo *DataPlaneInfo
Token *credentials.OAuthToken
ExpirationTime *time.Time
InfoRefresher func() (*DataPlaneInfo, error)
}

func (d *DataPlaneData) RefreshInfo() error {
newInfo, err := d.InfoRefresher()
if err != nil {
return err
}
d.DataPlaneInfo = newInfo
return nil
}

func (d *DataPlaneData) Expired() bool {
return d.Token == nil || d.ExpirationTime == nil || time.Now().After(*d.ExpirationTime)
}

func (d *DataPlaneData) ExpiresIn(duration time.Duration) bool {
return d.Token == nil || d.ExpirationTime == nil || time.Now().Add(duration).After(*d.ExpirationTime)
}

// TODO: Maybe add locks around the map?
type dataPlaneCache map[DataPlaneInfoKey]*DataPlaneData

// Refreshes the DataPlaneInfo
// This is only required if the backend updates the required permissions
// to call an endpoint, or the endpoint_url changes (which should be very rare).
// Calling this method also invalidates the token.
func (c *ApiClient) RefreshInfo(serviceName string,
path string) error {
key := DataPlaneInfoKey{
ServiceName: serviceName,
Path: path,
}
data, ok := c.dataPlaneCache[key]
if !ok {
return fmt.Errorf("data not found for service %s and endpoint %s", serviceName, path)
}
err := data.RefreshInfo()
if err != nil {
return err
}
data.Token = nil
data.ExpirationTime = nil
return nil
}

func (c *ApiClient) GetOAuthTokenForDataPlane(
serviceName string,
path string,
dataPlaneInfoGetter func() (*DataPlaneInfo, error),
controlPlaneTokenProvider func() (*oauth2.Token, error)) (*credentials.OAuthToken, error) {
key := DataPlaneInfoKey{
ServiceName: serviceName,
Path: path,
}
data, ok := c.dataPlaneCache[key]
if !ok {
data = &DataPlaneData{
InfoRefresher: dataPlaneInfoGetter,
}
data.RefreshInfo()
c.dataPlaneCache[key] = data
}
if data.ExpiresIn(2 * time.Minute) {
err := c.refreshToken(data, controlPlaneTokenProvider)
if err != nil {
return nil, err
}
}
return data.Token, nil
}

func (c *ApiClient) refreshToken(data *DataPlaneData, controlPlaneTokenProvider func() (*oauth2.Token, error)) error {
controlPlaneToken, err := controlPlaneTokenProvider()
if err != nil {
return err
}
newToken, err := c.GetOAuthToken(data.DataPlaneInfo.AuthorizationDetails, controlPlaneToken)
if err != nil {
return err
}
data.Token = newToken
expirationTime := time.Now().Add(time.Second * time.Duration(newToken.ExpiresIn))
data.ExpirationTime = &expirationTime
return nil
}
59 changes: 59 additions & 0 deletions httpclient/dataplane_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package httpclient

import (
"encoding/json"
"io"
"net/http"
"strings"
"testing"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

var mockTokenProvider = func() (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "dummy",
}, nil
}

var mockInfoGetter = func() (*DataPlaneInfo, error) {
return &DataPlaneInfo{
EndpointUrl: "endpoint_url",
AuthorizationDetails: "details",
}, nil
}

func TestGetOAuthTokenForDataPlane(t *testing.T) {
tokenResponse := &credentials.OAuthToken{
AccessToken: "dummyDataPlane",
ExpiresIn: 3600,
Scope: "scope",
TokenType: "Bearer",
}
marshalled, err := json.Marshal(tokenResponse)
require.NoError(t, err)
c := NewApiClient(ClientConfig{
Transport: hc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(string(marshalled))),
Request: r,
}, nil
}),
})
token, err := c.GetOAuthTokenForDataPlane("service", "path", mockInfoGetter, mockTokenProvider)
assert.NoError(t, err)
assert.Equal(t, "dummyDataPlane", token.AccessToken)
cachedData, ok := c.dataPlaneCache[DataPlaneInfoKey{
ServiceName: "service",
Path: "path",
}]
assert.True(t, ok)
assert.Equal(t, "details", cachedData.DataPlaneInfo.AuthorizationDetails)
assert.Equal(t, "endpoint_url", cachedData.DataPlaneInfo.EndpointUrl)
assert.Equal(t, "dummyDataPlane", cachedData.Token.AccessToken)
assert.False(t, cachedData.Expired())
}

0 comments on commit ffaab5d

Please sign in to comment.