diff --git a/httpclient/api_client.go b/httpclient/api_client.go index 4a1e9b1da..f2db0f86b 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -95,13 +95,15 @@ func NewApiClient(cfg ClientConfig) *ApiClient { Timeout: 0, Transport: transport, }, + dataPlaneCache: make(dataPlaneCache), } } type ApiClient struct { - config ClientConfig - rateLimiter *rate.Limiter - httpClient *http.Client + config ClientConfig + rateLimiter *rate.Limiter + httpClient *http.Client + dataPlaneCache dataPlaneCache } type DoOption struct { diff --git a/httpclient/dataplane.go b/httpclient/dataplane.go new file mode 100644 index 000000000..7211975fe --- /dev/null +++ b/httpclient/dataplane.go @@ -0,0 +1,112 @@ +package httpclient + +import ( + "fmt" + "time" + + "github.com/databricks/databricks-sdk-go/credentials" + "golang.org/x/oauth2" +) + +type DataPlaneInfoKey struct { + ServiceName string + Path string +} + +// TODO: Replace by the DataPlaneInfo generated by SDK Generator in the package oauth2. +// This will be generated once the first service uses the annotation for the DataPlaneInfo. +type DataPlaneInfo struct { + EndpointUrl string + AuthorizationDetails string +} + +type DataPlaneData struct { + DataPlaneInfo *DataPlaneInfo + Token *credentials.OAuthToken + ExpirationTime *time.Time + InfoRefresher func() (*DataPlaneInfo, error) +} + +func (d *DataPlaneData) RefreshInfo() error { + newInfo, err := d.InfoRefresher() + if err != nil { + return err + } + d.DataPlaneInfo = newInfo + return nil +} + +func (d *DataPlaneData) Expired() bool { + return d.Token == nil || d.ExpirationTime == nil || time.Now().After(*d.ExpirationTime) +} + +func (d *DataPlaneData) ExpiresIn(duration time.Duration) bool { + return d.Token == nil || d.ExpirationTime == nil || time.Now().Add(duration).After(*d.ExpirationTime) +} + +// TODO: Maybe add locks around the map? +type dataPlaneCache map[DataPlaneInfoKey]*DataPlaneData + +// Refreshes the DataPlaneInfo +// This is only required if the backend updates the required permissions +// to call an endpoint, or the endpoint_url changes (which should be very rare). +// Calling this method also invalidates the token. +func (c *ApiClient) RefreshInfo(serviceName string, + path string) error { + key := DataPlaneInfoKey{ + ServiceName: serviceName, + Path: path, + } + data, ok := c.dataPlaneCache[key] + if !ok { + return fmt.Errorf("data not found for service %s and endpoint %s", serviceName, path) + } + err := data.RefreshInfo() + if err != nil { + return err + } + data.Token = nil + data.ExpirationTime = nil + return nil +} + +func (c *ApiClient) GetOAuthTokenForDataPlane( + serviceName string, + path string, + dataPlaneInfoGetter func() (*DataPlaneInfo, error), + controlPlaneTokenProvider func() (*oauth2.Token, error)) (*credentials.OAuthToken, error) { + key := DataPlaneInfoKey{ + ServiceName: serviceName, + Path: path, + } + data, ok := c.dataPlaneCache[key] + if !ok { + data = &DataPlaneData{ + InfoRefresher: dataPlaneInfoGetter, + } + data.RefreshInfo() + c.dataPlaneCache[key] = data + } + if data.ExpiresIn(2 * time.Minute) { + err := c.refreshToken(data, controlPlaneTokenProvider) + if err != nil { + return nil, err + } + } + return data.Token, nil +} + +func (c *ApiClient) refreshToken(data *DataPlaneData, controlPlaneTokenProvider func() (*oauth2.Token, error)) error { + controlPlaneToken, err := controlPlaneTokenProvider() + if err != nil { + return err + } + newToken, err := c.GetOAuthToken(data.DataPlaneInfo.AuthorizationDetails, controlPlaneToken) + if err != nil { + return err + } + data.Token = newToken + expirationTime := time.Now().Add(time.Second * time.Duration(newToken.ExpiresIn)) + data.ExpirationTime = &expirationTime + return nil +} diff --git a/httpclient/dataplane_test.go b/httpclient/dataplane_test.go new file mode 100644 index 000000000..0a4ccae50 --- /dev/null +++ b/httpclient/dataplane_test.go @@ -0,0 +1,59 @@ +package httpclient + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "github.com/databricks/databricks-sdk-go/credentials" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +var mockTokenProvider = func() (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "dummy", + }, nil +} + +var mockInfoGetter = func() (*DataPlaneInfo, error) { + return &DataPlaneInfo{ + EndpointUrl: "endpoint_url", + AuthorizationDetails: "details", + }, nil +} + +func TestGetOAuthTokenForDataPlane(t *testing.T) { + tokenResponse := &credentials.OAuthToken{ + AccessToken: "dummyDataPlane", + ExpiresIn: 3600, + Scope: "scope", + TokenType: "Bearer", + } + marshalled, err := json.Marshal(tokenResponse) + require.NoError(t, err) + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(string(marshalled))), + Request: r, + }, nil + }), + }) + token, err := c.GetOAuthTokenForDataPlane("service", "path", mockInfoGetter, mockTokenProvider) + assert.NoError(t, err) + assert.Equal(t, "dummyDataPlane", token.AccessToken) + cachedData, ok := c.dataPlaneCache[DataPlaneInfoKey{ + ServiceName: "service", + Path: "path", + }] + assert.True(t, ok) + assert.Equal(t, "details", cachedData.DataPlaneInfo.AuthorizationDetails) + assert.Equal(t, "endpoint_url", cachedData.DataPlaneInfo.EndpointUrl) + assert.Equal(t, "dummyDataPlane", cachedData.Token.AccessToken) + assert.False(t, cachedData.Expired()) +}